# -*- coding: utf-8 -*-

import numpy
import redis
import pandas as pd
import math
import matplotlib
import talib

matplotlib.use("WXAgg", warn=True)  # 这个要紧跟在 import matplotlib 之后，而且必须安装了 wxpython 2.8 才行。
import matplotlib.pyplot as pyplot
from matplotlib.ticker import FixedLocator, MultipleLocator, LogLocator, FuncFormatter, NullFormatter


class KLine(object):
    def __init__(self):
        # 一条 K 线的宽度在 X 轴上所占距离（英寸）
        self._xfactor = 10.0 / 230.0

        # Y 轴上每一个距离单位的长度（英寸），这个单位距离是线性坐标和对数坐标通用的
        self._yfactor = 0.3

        # 底数，取得小一点，比较接近 1。股价 3 元到 4 元之间有大约 3 个单位距离
        self._expbase = 1.1

        #
        self._length = 0

        #
        self._rect_1 = (0, 0, 0, 0)
        self._rect_2 = (0, 0, 0, 0)
        #
        self._ylowlim_price = 0
        self._yhighlim_price = 0

        #
        self._orde_data = []
        self._orde_fangxiang = []
        self._orde_price = []

    def set_price(self, orde_data, orde_fangxiang, orde_price):
        self._orde_data = orde_data
        self._orde_fangxiang = orde_fangxiang
        self._orde_price = orde_price

    @staticmethod
    def convertbookid(order_book_id):
        codeid = order_book_id[:6]
        marketid = order_book_id[7:]
        if marketid == 'XSHE':
            marketid = 'KX:HISDAYKLINE:0000'
        elif marketid == 'XSHG':
            marketid = 'KX:HISDAYKLINE:0100'
        codeid = marketid + codeid + '0000000000000000000000'
        return codeid

    def get_kline_from_redis(self, order_book_id,
                             begindate, enddate):
        client = redis.Redis(host='120.55.149.71',
                             port=16379, db=0, password='redisserver')

        self.convertbookid(order_book_id)
        klinedata = client.zrevrangebyscore(order_book_id, begindate, enddate)

        dataset = list()
        highset = list()
        openset = list()
        lowset = list()
        closeset = list()
        volume = list()
        indexlist = list()
        avg = list()
        for kline in klinedata:
            tmp = str(kline)
            new = tmp[:-1]
            tmplist = new.split(' ', len(tmp))
            dataset.insert(0, tmplist[0])
            openset.insert(0, float(tmplist[1]))
            highset.insert(0, float(tmplist[2]))
            lowset.insert(0, float(tmplist[3]))
            closeset.insert(0, float(tmplist[4]))
            volume.insert(0, float(tmplist[6]))
            avg.insert(0, float(tmplist[7]) / float(tmplist[6]) / 100)
            indexlist = pd.Index(pd.Timestamp(str(d)) for d in dataset)
        return {
            'open': openset[:],
            'close': closeset[:],
            'high': highset[:],
            'low': lowset[:],
            'volume': volume[:],
            'avg': avg[:],
            'date': indexlist[:]
        }

    def get_kline_data(self, order_book_id, begindate, enddate):
        data = self.get_kline_from_redis(self.convertbookid(order_book_id),
                                         begindate, enddate)
        indexlist = pd.Index(pd.Timestamp(str(d)) for d in data['date'])
        return pd.DataFrame(data=data, index=indexlist)

    @staticmethod
    def get_idx(date, pdata):
        i = 0
        for i, orDate in enumerate(pdata['date']):
            if str(orDate)[0:10] == str(date)[0:10]:
                break
        return i

    def plot_figure(self, pdata):
        #   计算图片的尺寸（单位英寸）
        #   注意：Python2 里面， "1 / 10" 结果是 0, 必须写成 "1.0 / 10" 才会得到 0.1
        # ==================================================================================================================================================
        self._length = len(pdata['date'])  # 所有数据的长度，就是天数

        highest_price = max(pdata['high'])  # 最高价
        lowest_price = min([plow for plow in pdata['low'] if plow is not None])  # 最低价

        self._yhighlim_price = round(highest_price + 1, 2)  # K线子图 Y 轴最大坐标
        self._ylowlim_price = round(lowest_price - 1, 2)  # K线子图 Y 轴最小坐标

        # 价格在 Y 轴上的 “份数”
        ymulti_price = math.log(self._yhighlim_price, self._expbase) - math.log(self._ylowlim_price, self._expbase)

        ymulti_vol = 3.0  # 成交量部分在 Y 轴所占的 “份数”
        ymulti_top = 0.2  # 顶部空白区域在 Y 轴所占的 “份数”
        ymulti_bot = 0.8  # 底部空白区域在 Y 轴所占的 “份数”

        xmulti_left = 5.0  # 左侧空白区域所占的 “份数”
        xmulti_right = 3.0  # 右侧空白区域所占的 “份数”

        xmulti_all = self._length + xmulti_left + xmulti_right
        xlen_fig = xmulti_all * self._xfactor  # 整个 Figure 的宽度
        ymulti_all = ymulti_price + ymulti_vol + ymulti_top + ymulti_bot
        ylen_fig = ymulti_all * self._yfactor  # 整个 Figure 的高度

        self._rect_1 = (xmulti_left / xmulti_all, (ymulti_bot + ymulti_vol) / ymulti_all, self._length / xmulti_all,
                        ymulti_price / ymulti_all)  # K线图部分
        self._rect_2 = (xmulti_left / xmulti_all, ymulti_bot / ymulti_all,
                        self._length / xmulti_all, ymulti_vol / ymulti_all)  # 成交量部分

        #   建立 Figure 对象
        # ==================================================================================================================================================
        figfacecolor = 'white'
        figedgecolor = 'black'
        # figdpi = 150
        figlinewidth = 1.0

        figobj = pyplot.figure(figsize=(xlen_fig, ylen_fig), facecolor=figfacecolor, edgecolor=figedgecolor,
                               linewidth=figlinewidth)  # Figure 对象

        return figobj

    # def plot_volumn(self,pdata):

    def plotkline(self, pdata):

        figobj = self.plot_figure(pdata)
        # ==================================================================================================================================================
        # ==================================================================================================================================================
        # =======    成交量部分
        # ==================================================================================================================================================
        # ==================================================================================================================================================

        #   添加 Axes 对象
        # ==================================================================================================================================================
        axes_2 = figobj.add_axes(self._rect_2, axis_bgcolor='black')
        axes_2.set_axisbelow(True)  # 网格线放在底层

        #   改变坐标线的颜色
        # ==================================================================================================================================================
        for child in axes_2.get_children():
            if isinstance(child, matplotlib.spines.Spine):
                child.set_color('lightblue')

        # 得到 X 轴 和 Y 轴 的两个 Axis 对象
        # ==================================================================================================================================================
        xaxis_2 = axes_2.get_xaxis()
        yaxis_2 = axes_2.get_yaxis()

        #   设置两个坐标轴上的 grid
        # ==================================================================================================================================================
        xaxis_2.grid(True, 'major', color='0.3', linestyle='solid', linewidth=0.2)
        xaxis_2.grid(True, 'minor', color='0.3', linestyle='dotted', linewidth=0.1)

        yaxis_2.grid(True, 'major', color='0.3', linestyle='solid', linewidth=0.2)
        yaxis_2.grid(True, 'minor', color='0.3', linestyle='dotted', linewidth=0.1)

        # ==================================================================================================================================================
        # =======    绘图
        # ==================================================================================================================================================
        xindex = numpy.arange(self._length)  # X 轴上的 index，一个辅助数据

        zipoc = zip(pdata['open'], pdata['close'])
        up = numpy.array([True if po < pc and po is not None else False for po, pc in zipoc])  # 标示出该天股价日内上涨的一个序列
        down = numpy.array([True if po > pc and po is not None else False for po, pc in zipoc])  # 标示出该天股价日内下跌的一个序列
        side = numpy.array([True if po == pc and po is not None else False for po, pc in zipoc])  # 标示出该天股价日内走平的一个序列

        volume = pdata['volume']
        rarray_vol = numpy.array(volume)
        volzeros = numpy.zeros(self._length)  # 辅助数据

        # XXX: 如果 up/down/side 各项全部为 False，那么 vlines() 会报错。
        if True in up:
            axes_2.vlines(xindex[up], volzeros[up], rarray_vol[up], color='red', linewidth=3.0, label='_nolegend_')
        if True in down:
            axes_2.vlines(xindex[down], volzeros[down], rarray_vol[down], color='green', linewidth=3.0,
                          label='_nolegend_')
        if True in side:
            axes_2.vlines(xindex[side], volzeros[side], rarray_vol[side], color='0.7', linewidth=3.0,
                          label='_nolegend_')

        # 设定 X 轴坐标的范围
        # ==================================================================================================================================================
        axes_2.set_xlim(-1, self._length)

        #   设定 X 轴上的坐标
        # ==================================================================================================================================================
        datelist = []
        for dstr in pdata['date']:
            datelist.append(pd.to_datetime(dstr))
        # 确定 X 轴的 MajorLocator
        mdindex = []  # 每个月第一个交易日在所有日期列表中的 index
        years = set([d.year for d in datelist])  # 所有的交易年份

        for y in sorted(years):
            months = set([d.month for d in datelist if d.year == y])  # 当年所有的交易月份
            for m in sorted(months):
                monthday = min([dt for dt in datelist if dt.year == y and dt.month == m])  # 当月的第一个交易日
                mdindex.append(datelist.index(monthday))

        x_major_locator = FixedLocator(numpy.array(mdindex))

        # 确定 X 轴的 MinorLocator
        wdindex = []  # 每周第一个交易日在所有日期列表中的 index
        for d in datelist:
            if d.weekday() == 0:
                wdindex.append(datelist.index(d))

        x_minor_locator = FixedLocator(numpy.array(wdindex))

        # 确定 X 轴的 MajorFormatter 和 MinorFormatter
        def x_major_formatter_2(idx, pos=None):
            return datelist[int(idx)].strftime('%Y-%m-%d')

        def x_minor_formatter_2(idx, pos=None):
            return datelist[int(idx)].strftime('%m-%d')

        x_major_formatter = FuncFormatter(x_major_formatter_2)
        x_minor_formatter = FuncFormatter(x_minor_formatter_2)

        # 设定 X 轴的 Locator 和 Formatter
        xaxis_2.set_major_locator(x_major_locator)
        xaxis_2.set_major_formatter(x_major_formatter)

        xaxis_2.set_minor_locator(x_minor_locator)
        xaxis_2.set_minor_formatter(x_minor_formatter)

        # 设定 X 轴主要坐标点与辅助坐标点的样式
        for malabel in axes_2.get_xticklabels(minor=False):
            malabel.set_fontsize(3)
            malabel.set_horizontalalignment('right')
            malabel.set_rotation('30')

        for milabel in axes_2.get_xticklabels(minor=True):
            milabel.set_fontsize(2)
            milabel.set_horizontalalignment('right')
            milabel.set_rotation('30')

        # 设定 Y 轴坐标的范围
        # ==================================================================================================================================================
        maxvol = max(volume)  # 注意是 int 类型
        # minvol = min(volume)

        axes_2.set_ylim(0, maxvol)

        #   设定 Y 轴上的坐标
        # ==================================================================================================================================================
        max_vollen = len(str(maxvol))
        # min_vollen = len(str(minvol))
        y_major_locator_2 = MultipleLocator(10 ** (max_vollen - 1))
        # y_minor_locator_2 = MultipleLocator(10 ** (min_vollen - 1))
        y_minor_locator_2 = MultipleLocator((10 ** (max_vollen - 2)) * 5)

        # 确定 Y 轴的 MajorFormatter
        #   def y_major_formatter_2(num, pos=None):
        #       numtable= {'1':u'一', '2':u'二', '3':u'三', '4':u'四', '5':u'五', '6':u'六', '7':u'七', '8':u'八', '9':u'九', }
        #       dimtable= {3:u'百', 4:u'千', 5:u'万', 6:u'十万', 7:u'百万', 8:u'千万', 9:u'亿', 10:u'十亿', 11:u'百亿'}
        #       return numtable[str(num)[0]] + dimtable[vollen] if num != 0 else '0'

        def y_major_formatter_2(num, pos=None):
            return int(num)

        y_major_formatter_2 = FuncFormatter(y_major_formatter_2)

        # 确定 Y 轴的 MinorFormatter
        def y_minor_formatter_2(num, pos=None):
            return int(num)

        # y_minor_formatter_2 = FuncFormatter(y_minor_formatter_2)
        y_minor_formatter_2 = NullFormatter()

        # 设定 X 轴的 Locator 和 Formatter
        yaxis_2.set_major_locator(y_major_locator_2)
        yaxis_2.set_major_formatter(y_major_formatter_2)

        yaxis_2.set_minor_locator(y_minor_locator_2)
        yaxis_2.set_minor_formatter(y_minor_formatter_2)

        # 设定 Y 轴主要坐标点与辅助坐标点的样式
        for malab in axes_2.get_yticklabels(minor=False):
            malab.set_fontsize(3)

        for milab in axes_2.get_yticklabels(minor=True):
            milab.set_fontsize(2)

        # ==================================================================================================================================================
        # ==================================================================================================================================================
        # =======    K 线图部分
        # ==================================================================================================================================================
        # ==================================================================================================================================================

        #   添加 Axes 对象
        # ==================================================================================================================================================
        axes_1 = figobj.add_axes(self._rect_1, axis_bgcolor='black', sharex=axes_2)
        axes_1.set_axisbelow(True)  # 网格线放在底层

        axes_1.set_yscale('log', basey=self._expbase)  # 使用对数坐标

        # 改变坐标线的颜色
        # ==================================================================================================================================================
        for child in axes_1.get_children():
            if isinstance(child, matplotlib.spines.Spine):
                child.set_color('lightblue')

        # 得到 X 轴 和 Y 轴 的两个 Axis 对象
        # ==================================================================================================================================================
        xaxis_1 = axes_1.get_xaxis()
        yaxis_1 = axes_1.get_yaxis()

        #   设置两个坐标轴上的 grid
        # ==================================================================================================================================================
        xaxis_1.grid(True, 'major', color='0.3', linestyle='solid', linewidth=0.2)
        xaxis_1.grid(True, 'minor', color='0.3', linestyle='dotted', linewidth=0.1)

        yaxis_1.grid(True, 'major', color='0.3', linestyle='solid', linewidth=0.2)
        yaxis_1.grid(True, 'minor', color='0.3', linestyle='dotted', linewidth=0.1)

        # ==================================================================================================================================================
        # =======    绘图
        # ==================================================================================================================================================

        #   绘制 K 线部分
        # ==================================================================================================================================================
        rarray_open = numpy.array(pdata['open'])
        rarray_close = numpy.array(pdata['close'])
        rarray_high = numpy.array(pdata['high'])
        rarray_low = numpy.array(pdata['low'])

        # XXX: 如果 up, down, side 里有一个全部为 False 组成，那么 vlines() 会报错。
        if True in up:
            axes_1.vlines(xindex[up], rarray_low[up], rarray_high[up], color='red', linewidth=0.6, label='_nolegend_')
            axes_1.vlines(xindex[up], rarray_open[up], rarray_close[up], color='red', linewidth=3.0, label='_nolegend_')
            # for x, y in xindex[up], rarray_high[up]:

        if True in down:
            axes_1.vlines(xindex[down], rarray_low[down], rarray_high[down], color='green', linewidth=0.6,
                          label='_nolegend_')
            axes_1.vlines(xindex[down], rarray_open[down], rarray_close[down], color='green', linewidth=3.0,
                          label='_nolegend_')
        if True in side:
            axes_1.vlines(xindex[side], rarray_low[side], rarray_high[side], color='0.7', linewidth=0.6,
                          label='_nolegend_')
            axes_1.vlines(xindex[side], rarray_open[side] + 0.02, rarray_close[side], color='0.7', linewidth=3.0,
                          label='_nolegend_')
        # add price
        # ==================================================================================================================================================
        '''
        axes_1.annotate(str(rarray_high[1].round(2)), xy=(xindex[1], rarray_high[1]),
                        xytext=(xindex[1], rarray_high[1] + 1),
                        color='white', fontsize=4.0,
                        arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2", color='white', linewidth=0.4))

        axes_1.annotate(str(rarray_low[10].round(2)), xy=(xindex[10], rarray_low[10]),
                        xytext=(xindex[10], rarray_low[10] - 1),
                        color='w', fontsize=4.0,
                        arrowprops=dict(arrowstyle="->", color='w', linewidth=0.4))
        '''
        for i, itemdate in enumerate(self._orde_data):

            itemfx = self._orde_fangxiang[i]
            itemprice = self._orde_price[i]
            price_idx = self.get_idx(itemdate, pdata)

            if itemfx > 0:
                axes_1.annotate(str(itemprice.round(2)), xy=(xindex[price_idx], itemprice),
                                xytext=(xindex[price_idx], rarray_low[price_idx] - 1),
                                color='red', fontsize=14.0,
                                arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2", color='purple',
                                                linewidth=1.5))
            elif itemfx < 0:
                axes_1.annotate(str(itemprice.round(2)), xy=(xindex[price_idx], itemprice),
                                xytext=(xindex[price_idx], rarray_high[price_idx] + 1),
                                color='g', fontsize=14.0,
                                arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2", color='w',
                                                linewidth=1.5))

        # 绘制均线部分
        # ==================================================================================================================================================

        rarray_1dayave = numpy.array(pdata['avg'])

        temp = talib.SMA(rarray_1dayave, 5)
        rarray_5dayave = numpy.array(temp)
        temp = talib.SMA(rarray_1dayave, 10)
        rarray_10dayave = numpy.array(temp)
        temp = talib.SMA(rarray_1dayave, 20)
        rarray_20dayave = numpy.array(temp)
        temp = talib.SMA(rarray_1dayave, 60)
        rarray_60dayave = numpy.array(temp)

        axes_1.plot(xindex, rarray_5dayave, 'o-', color='white', linewidth=0.5, markersize=0.7, markeredgecolor='white',
                    markeredgewidth=0.1)  # 5日加权均线
        axes_1.plot(xindex, rarray_10dayave, 'o-', color='yellow', linewidth=0.5, markersize=0.7,
                    markeredgecolor='yellow',
                    markeredgewidth=0.1)  # 10日均线
        axes_1.plot(xindex, rarray_20dayave, 'o-', color='purple', linewidth=0.5, markersize=0.7,
                    markeredgecolor='green',
                    markeredgewidth=0.1)  # 20日均线
        axes_1.plot(xindex, rarray_60dayave, 'o-', color='green', linewidth=0.5, markersize=0.7,
                    markeredgecolor='green',
                    markeredgewidth=0.1)  # 60日均线

        #   设定 X 轴坐标的范围
        # ==================================================================================================================================================
        axes_1.set_xlim(-1, self._length)

        #   先设置 label 位置，再将 X 轴上的坐标设为不可见。因为与 成交量子图 共用 X 轴
        # ==================================================================================================================================================

        # 设定 X 轴的 Locator 和 Formatter
        xaxis_1.set_major_locator(x_major_locator)
        xaxis_1.set_major_formatter(x_major_formatter)

        xaxis_1.set_minor_locator(x_minor_locator)
        xaxis_1.set_minor_formatter(x_minor_formatter)

        # 将 X 轴上的坐标设为不可见。
        for malab in axes_1.get_xticklabels(minor=False):
            malab.set_visible(False)

        for milab in axes_1.get_xticklabels(minor=True):
            milab.set_visible(False)

        # 用这一段效果也一样
        #   pyplot.setp(axes_1.get_xticklabels(minor=False), visible=False)
        #   pyplot.setp(axes_1.get_xticklabels(minor=True), visible=False)

        # 设定 Y 轴坐标的范围
        #  ==================================================================================================================================================
        axes_1.set_ylim(self._ylowlim_price, self._yhighlim_price)

        #   设定 Y 轴上的坐标
        # ==================================================================================================================================================

        # 主要坐标点
        #  -----------------------------------------------------
        y_major_locator_1 = LogLocator(base=self._expbase)

        y_major_formatter_1 = NullFormatter()

        # 设定 Y 轴的 Locator 和 Formatter
        yaxis_1.set_major_locator(y_major_locator_1)
        yaxis_1.set_major_formatter(y_major_formatter_1)

        # 设定 Y 轴主要坐标点与辅助坐标点的样式
        #   for mal in axes_1.get_yticklabels(minor=False):
        #       mal.set_fontsize(3)

        #   辅助坐标点
        # -----------------------------------------------------
        minorticks = range(int(self._ylowlim_price), int(self._yhighlim_price) + 1, 1)

        y_minor_locator_1 = FixedLocator(numpy.array(minorticks))

        # 确定 Y 轴的 MinorFormatter
        def y_minor_formatter_1(num, pos=None):
            return str(num / 1.0) + '0'

        y_minor_formatter_1 = FuncFormatter(y_minor_formatter_1)

        # 设定 X 轴的 Locator 和 Formatter
        yaxis_1.set_minor_locator(y_minor_locator_1)
        yaxis_1.set_minor_formatter(y_minor_formatter_1)

        # 设定 Y 轴主要坐标点与辅助坐标点的样式
        for mil in axes_1.get_yticklabels(minor=True):
            mil.set_fontsize(3)

            # 保存图片
            # ==================================================================================================================================================
            # figobj.savefig(figpath, dpi=figdpi, facecolor=figfacecolor,
            # edgecolor=figedgecolor, linewidth=figlinewidth)

    def plot_kline(self, order_book_id, begindate, enddate):
        pdata = self.get_kline_data(order_book_id, begindate, enddate)
        self.plotkline(pdata)
        pyplot.show()


if __name__ == '__main__':
    k = KLine()
    k.plot_kline(order_book_id='601311.XSHG', begindate=20160809, enddate=20160101)
